
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from nets.dcgan import discriminator, generator
from utils.utils import get_lr_scheduler, set_optimizer_lr
from utils.utils_fit import fit_one_epoch

if __name__ == "__main__":
    #卷积通道数的设置
    channel         = 64
    #图像大小的设置，如[128, 128]
    input_shape     = [64,64]

    #训练参数设置
    Init_Epoch      = 0
    Epoch           = 500
    batch_size      = 1

    #Init_lr 模型的最大学习率

    Init_lr             = 2e-3
    # Min_lr  模型的最小学习率，默认为最大学习率的0.01
    Min_lr              = Init_lr * 0.01

    # adam优化器
    optimizer_type      = "adam"
    momentum            = 0.5
    weight_decay        = 0

    #   lr_decay_type   使用到的学习率下降方式
    lr_decay_type       = "cos"

    #------------------------------------------------------------------#
    #   save_dir        权值与日志文件保存的文件夹
    #------------------------------------------------------------------#
    save_dir            = 'logs'
    #------------------------------------------------------------------#
    #   num_workers     用于设置是否使用多线程读取数据
    #                   开启后会加快数据读取速度，但是会占用更多内存
    #                   内存较小的电脑可以设置为2或者0
    #------------------------------------------------------------------#
    num_workers         = 0
    #------------------------------#
    #   每隔50个step保存一次图片
    #------------------------------#
    photo_save_step     = 50

    #------------------------------------------#
    #   获得图片路径
    #------------------------------------------#
    annotation_path = "train_lines.txt"

    #------------------------------------------------------#
    #   设置用到的显卡
    #------------------------------------------------------#
    ngpus_per_node  = torch.cuda.device_count()
    device          = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    local_rank      = 0

    #------------------------------------------#
    #   生成网络和评价网络
    #------------------------------------------#
    G_model = generator(channel, input_shape)
    D_model = discriminator(channel, input_shape)

    #------------------------------------------#
    #   将训练好的模型重新载入
    #------------------------------------------#

    #----------------------#
    #   获得损失函数
    #----------------------#
    BCE_loss = nn.BCEWithLogitsLoss()


    scaler = None

    G_model_train = G_model.train()
    D_model_train = D_model.train()

    cudnn.benchmark = True
    G_model_train = torch.nn.DataParallel(G_model)
    G_model_train = G_model_train.cuda()
    D_model_train = torch.nn.DataParallel(D_model)
    D_model_train = D_model_train.cuda()

    with open(annotation_path) as f:
        lines = f.readlines()
    num_train = len(lines)

    #------------------------------------------------------#
    #   Init_Epoch为起始世代
    #   Epoch总训练世代
    #------------------------------------------------------#
    if True:
        #---------------------------------------#
        #   根据optimizer_type选择优化器
        #---------------------------------------#
        G_optimizer = {
            'adam'  : optim.Adam(G_model_train.parameters(), lr=Init_lr, betas=(momentum, 0.999), weight_decay = weight_decay),
            'sgd'   : optim.SGD(G_model_train.parameters(), Init_lr, momentum = momentum, nesterov=True)
        }[optimizer_type]

        D_optimizer = {
            'adam'  : optim.Adam(D_model_train.parameters(), lr=Init_lr, betas=(momentum, 0.999), weight_decay = weight_decay),
            'sgd'   : optim.SGD(D_model_train.parameters(), Init_lr, momentum = momentum, nesterov=True)
        }[optimizer_type]

        #   获得学习率下降的公式
        lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr, Min_lr, Epoch)

        #   判断每一个世代的长度
        epoch_step      = num_train // batch_size


        #   开始模型训练-#
        for epoch in range(Init_Epoch, Epoch):
            set_optimizer_lr(G_optimizer, lr_scheduler_func, epoch)
            set_optimizer_lr(D_optimizer, lr_scheduler_func, epoch)

            fit_one_epoch(G_model_train, D_model_train, G_model, D_model, G_optimizer, D_optimizer, BCE_loss,
                        epoch, epoch_step, gen, Epoch, scaler,  save_dir, photo_save_step, local_rank)
